import tensorflow as tf
import numpy as np
class nist_layer(tf.keras.layers.Layer):
    """Dense layer whose kernel is restricted to a fixed cyclic band.

    Input unit ``i`` is connected to output units ``(i + k) % units`` for
    ``k in range(neuroseed_factor)``. Every other kernel entry is zeroed in
    the forward pass by a constant 0/1 mask stored in ``self.w``.

    Args:
        units: Number of output units.
        neuroseed_factor: Band width, i.e. how many consecutive (cyclic)
            output units each input unit connects to. Must not exceed
            ``units``.
        activation: Activation identifier, resolved with
            ``tf.keras.activations.get``.
        kernel_regularizer: Regularizer identifier, resolved with
            ``tf.keras.regularizers.get``. It is applied to the full
            ``raw_kernel``, including entries outside the band.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, neuroseed_factor, activation=None, kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = tf.keras.activations.get(activation)
        self.neuroseed_factor = neuroseed_factor
        self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)

    def build(self, batch_input_shape):
        """Create the band mask, the kernel and the bias for the input shape.

        Raises:
            ValueError: if the last input dimension is unknown, or if
                ``neuroseed_factor`` exceeds ``units``.
        """
        indim = batch_input_shape[-1]
        if indim is None:
            raise ValueError("The last dimension of the input must be defined")
        self.indim = int(indim)
        self.outdim = self.units

        if self.neuroseed_factor > self.outdim:
            raise ValueError("Growth state cannot exceed output dimension")

        # Cyclic band mask: row i has ones at columns (i + k) % outdim for
        # k = 0 .. neuroseed_factor - 1. Fancy indexing broadcasts the
        # (indim, 1) row indices against the (indim, factor) column indices.
        rows = np.arange(self.indim)[:, np.newaxis]
        cols = (rows + np.arange(self.neuroseed_factor)) % self.outdim
        mask = np.zeros((self.indim, self.outdim), dtype=np.float32)
        mask[rows, cols] = 1.0
        self.w = mask

        # Report how many kernel entries are live versus a dense kernel.
        self.n_param = np.count_nonzero(self.w)
        self.n_param_dense = self.w.shape[0] * self.w.shape[1]
        print("number of Weights(Sparse): {}, \nNumber of Weights(Dense): {}".format(self.n_param, self.n_param_dense))

        # Passing the regularizer to add_weight (instead of calling add_loss
        # once here) lets Keras re-evaluate the penalty on the current
        # weights at every training step.
        self.raw_kernel = self.add_weight(
            name="raw_kernel",
            shape=(self.indim, self.units),
            initializer="random_normal",
            regularizer=self.kernel_regularizer,
            trainable=True,
        )
        self.bias = self.add_weight(name="bias", shape=[self.units], initializer="random_normal")

    def call(self, X):
        """Return ``activation(X @ (raw_kernel * mask) + bias)``."""
        # Cast the numpy mask to the layer's compute dtype so the product
        # also works under mixed precision or non-float32 layer dtypes.
        mask = tf.cast(self.w, self.compute_dtype)
        masked_kernel = self.raw_kernel * mask

        # Straight-through estimator. The forward value equals masked_kernel,
        # but the gradient w.r.t. raw_kernel is the identity. As a result,
        # entries outside the band also receive (unmasked) gradient updates.
        # NOTE(review): if true gradient masking was intended, drop this
        # line; the multiplication above already zeroes those gradients.
        masked_kernel = tf.stop_gradient(masked_kernel - self.raw_kernel) + self.raw_kernel

        return self.activation((X @ masked_kernel) + self.bias)

    def get_config(self):
        """Return every constructor argument so ``from_config`` can rebuild the layer."""
        base_config = super().get_config()
        return {
            **base_config,
            "units": self.units,
            "neuroseed_factor": self.neuroseed_factor,
            "activation": tf.keras.activations.serialize(self.activation),
            "kernel_regularizer": tf.keras.regularizers.serialize(self.kernel_regularizer),
        }
